import sys,os,logging,argparse,json,numpy

def output_time(s, indent, obj):
    """Render timing measurements as an indented, human-readable string.

    Args:
        s: Prefix that the dict branch appends its lines to. The list and
            scalar branches ignore it (recursive calls always pass "").
        indent: Number of tab characters placed before each dict key.
        obj: One of
            * a flat list ``[start0, end0, start1, end1, ...]``, rendered as
              the mean of ``end_i - start_i`` over complete pairs;
            * a dict mapping labels to any of these forms, rendered one key
              per line and nested one tab deeper per level;
            * a number, rendered directly as seconds.

    Returns:
        The formatted string. Numbers use the ``"{:7.5f} s"`` format. Any
        other leaf falls back to ``repr(obj)`` so a dict containing it still
        renders.
    """
    if isinstance(obj, list):
        # Use only complete (start, end) pairs. Previously an unmatched
        # trailing start was broadcast against the ends, which produced a
        # silently wrong mean.
        n_pairs = len(obj) // 2
        if n_pairs == 0:
            # Same text numpy's empty mean produced, without the RuntimeWarning.
            return f"{float('nan'):7.5f} s"
        st = numpy.array(obj[0:2 * n_pairs:2])
        ed = numpy.array(obj[1:2 * n_pairs:2])
        dur = (ed - st).mean()
        return f"{dur:7.5f} s"
    if isinstance(obj, dict):
        parts = [s]
        for k, v in obj.items():
            parts.append("\n" + "\t" * indent + f"{k:15}:")
            parts.append(output_time("", indent + 1, v))
        return "".join(parts)
    if isinstance(obj, (int, float, numpy.number)) and not isinstance(obj, bool):
        return f"{obj:7.5f} s"
    # Previously this fell through and returned None, which crashed the dict
    # branch's string concatenation.
    return repr(obj)

def get_args():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    All options are registered in the ``world`` argument group:

    * ``--local_rank`` (int, default -1): local rank passed by
      torch.distributed.launch.
    * ``--exp_group`` / ``-G`` (str, default ""): experiment group name.
    * ``--action`` / ``-A`` (str, default "f"): the function to carry out.
    """
    parser = argparse.ArgumentParser()
    world = parser.add_argument_group('world')
    option_table = (
        (('--local_rank',),
         dict(type=int, default=-1, help='Local rank passed by torch.distributed.launch')),
        (('--exp_group', "-G"),
         dict(type=str, default="", help='The group name of the experiment')),
        (('--action', "-A"),
         dict(type=str, default="f", help='The function to be carried out')),
    )
    for flags, options in option_table:
        world.add_argument(*flags, **options)
    return parser.parse_args()

def write_args(args, logger):
    """Log every attribute of ``args`` as ``name: value`` between two banners.

    Attributes are taken from ``vars(args)`` in insertion order, and each line
    is emitted with ``logger.info``.
    """
    banner = "-------------- args --------------"
    logger.info(banner)
    for name, val in vars(args).items():
        logger.info("{}: {}".format(name, val))
    logger.info(banner)

def simplify_keys(dic):
    """Return a copy of ``dic`` whose string keys keep only their last dotted part.

    ``"Data.Name"`` becomes ``"Name"``. Non-string keys are kept unchanged.
    If two keys collapse to the same name, the later entry's value wins.
    """
    return {
        (key.rpartition(".")[2] if isinstance(key, str) else key): value
        for key, value in dic.items()
    }

def write_group_log(dic, group_log_path):
    """Append ``dic`` to the group log as a two-row (titles / values) table.

    Before rendering, keys are reduced to their last dotted component with
    ``simplify_keys``. ``clean_dict`` then strips surrounding spaces, quotes,
    tabs, newlines, commas and periods from keys and string values. The table
    is rendered with ``pprint(..., True)`` and followed by a newline.
    """
    # Render before opening the file. If formatting raises (e.g. pprint on an
    # empty dict), the log file is then neither created nor touched.
    row = pprint(clean_dict(simplify_keys(dic)), True) + "\n"
    with open(group_log_path, "a") as f:
        f.write(row)

def clean_dict(dic):
    """Return a copy of ``dic`` with junk characters trimmed from strings.

    Keys and values that are strings have leading and trailing spaces, single
    quotes, tabs, newlines, commas and periods removed. Other objects are kept
    as-is. If two keys become equal after trimming, the later value wins.
    """
    junk = " '\t\n,."

    def _tidy(item):
        return item.strip(junk) if isinstance(item, str) else item

    return {_tidy(key): _tidy(value) for key, value in dic.items()}

def generate_path(args, base_path=None):
    """Create (if needed) and return the output locations for one experiment.

    Layout::

        <base_path>/<exp_group>/log.log                      group-level log
        <base_path>/<exp_group>/<action>-<Data.Name>-<hash>/
            trash.log, log.log, img/, model/

    ``<hash>`` is ``string_hash(sort_dict(args.to_dict()))`` computed with
    ``local_rank`` forced to 0. Every rank of the same configuration
    therefore resolves to the same directory.

    Args:
        args: Experiment config providing ``exp_group``, ``action``,
            ``Data.Name`` and ``to_dict()``. The plain Namespace from
            ``get_args`` lacks the last two, so presumably a config wrapper is
            passed here (confirm with callers). NOTE(review): ``local_rank``
            is written into the mapping returned by ``to_dict()``. If that is
            the config's live storage, the caller's config is modified.
        base_path: Root output directory. Defaults to
            ``/root/autodl-tmp/outputs``, the previously hard-coded location.

    Returns:
        ``(trash_path, log_path, img_path, model_path, group_log_path)``
    """
    if base_path is None:
        # Scratch disk, chosen for storage reasons. The repo-relative
        # "<repo>/outputs" path that used to be computed here was always
        # overwritten, so it has been removed.
        base_path = os.path.join("/root/autodl-tmp", "outputs")
    group_path = os.path.join(base_path, args.exp_group)
    os.makedirs(group_path, exist_ok=True)

    args_dict = args.to_dict()
    # Normalize the rank so all processes of one run hash identically.
    args_dict["local_rank"] = 0
    digest = string_hash(sort_dict(args_dict))

    exp_path = os.path.join(group_path, f"{args.action}-{args.Data.Name}-{digest}")
    os.makedirs(exp_path, exist_ok=True)

    trash_path = os.path.join(exp_path, "trash.log")
    log_path = os.path.join(exp_path, "log.log")
    img_path = os.path.join(exp_path, "img")
    model_path = os.path.join(exp_path, "model")
    os.makedirs(img_path, exist_ok=True)
    os.makedirs(model_path, exist_ok=True)

    group_log_path = os.path.join(group_path, "log.log")

    paths = (trash_path, log_path, img_path, model_path, group_log_path)
    # One path per line. The old print(a, "\n", b, ...) prefixed every line
    # after the first with a stray space.
    print(*paths, sep="\n")
    return paths

import torch, os, numpy, re, random
from hashlib import md5
import torch.cuda as cuda
import torch.distributed as distributed

def ptime(second = True):
    """Return the current local time as a short human-readable stamp.

    Args:
        second: When true, include seconds (``"%d %b. %H:%M:%S"``, e.g.
            ``"05 Mar. 14:07:33"`` in an English locale). Otherwise stop at
            minutes (``"%d %b. %H:%M"``).
    """
    # ``datetime`` is not imported anywhere in this module, so the original
    # raised NameError on every call.
    import datetime

    fmt = "%d %b. %H:%M:%S" if second else "%d %b. %H:%M"
    return datetime.datetime.now().strftime(fmt)
        
def _pprint_cell(item):
    # Text of one table cell: repr truncated to 20 characters.
    return repr(item)[:20]


def _pprint_row(cells, width):
    # Join cells right-justified to a common width.
    return "".join(cell.rjust(width) for cell in cells)


def pprint(obj, with_title = False):
    """Render a dict or a list/tuple as a compact right-aligned table.

    Every cell is ``repr(item)`` truncated to 20 characters. All cells share
    one column width: the longest rendered cell plus one.

    Args:
        obj: A dict, list or tuple.
            * dict: titles are the keys, values are the values.
            * list/tuple with ``with_title``: items alternate title, value,
              title, value, ...
            * list/tuple without ``with_title``: every item is a value.
        with_title: If true, return ``"<titles row>\\n<values row>"``.
            Otherwise return only the values row.

    Returns:
        The rendered table. An empty container gives ``""`` (or ``"\\n"``
        with titles) instead of raising.

    Raises:
        TypeError: if ``obj`` is not a dict, list or tuple.
    """
    if isinstance(obj, dict):
        titles, values = obj.keys(), obj.values()
    elif isinstance(obj, (list, tuple)):
        titles, values = (obj[::2], obj[1::2]) if with_title else ((), obj)
    else:
        # Previously fell through to an UnboundLocalError on `s`.
        raise TypeError(f"pprint expects a dict, list or tuple, got {type(obj).__name__}")

    # repr() instead of item.__repr__(): the latter fails for class objects.
    value_cells = [_pprint_cell(v) for v in values]
    title_cells = [_pprint_cell(k) for k in titles] if with_title else []
    # default=0 keeps empty containers from raising ValueError in max().
    width = max((len(c) for c in title_cells + value_cells), default=0) + 1

    value_row = _pprint_row(value_cells, width)
    if not with_title:
        return value_row
    return _pprint_row(title_cells, width) + "\n" + value_row

def gpu_summary(device = None):
    """Return torch's text report of the CUDA caching allocator for ``device``.

    Thin wrapper around ``torch.cuda.memory_summary``. ``device=None`` lets
    torch choose its default (current) device.
    """
    report = cuda.memory_summary(device=device)
    return report

def str_get(string,module,**kwargs):
    """Look up ``string`` in ``module``'s namespace and call it with ``kwargs``.

    Only names in ``vars(module)`` (the object's own ``__dict__``) are
    considered. Inherited or class-level attributes are not.

    Args:
        string: Name of the callable to fetch.
        module: Module (or any object with a ``__dict__``) to search.
        **kwargs: Keyword arguments forwarded to the callable.

    Returns:
        Whatever the callable returns.

    Raises:
        NotImplementedError: if ``string`` is not defined in ``module``.
    """
    namespace = vars(module)
    try:
        factory = namespace[string]
    except KeyError:
        # Name the missing entry; the bare NotImplementedError gave no clue.
        owner = getattr(module, "__name__", type(module).__name__)
        raise NotImplementedError(f"{string!r} is not defined in {owner}") from None
    # Called outside the try so KeyErrors raised by the factory propagate.
    return factory(**kwargs)

def module_structure(file):
    """Convert a source-file path into a dotted module path relative to the CWD.

    ``"pkg/sub/mod.py"`` becomes ``"pkg.sub.mod"``. Only the final extension
    is dropped, and path separators become dots. NOTE: paths outside the CWD
    (``relpath`` yields ``..`` components) still do not produce an
    importable name.
    """
    rel_path = os.path.relpath(file)
    # splitext removes only the last extension. The old split(".")[0] cut at
    # the first dot anywhere, so dotted directories were truncated and
    # "../x.py" returned "".
    stem = os.path.splitext(rel_path)[0]
    return stem.replace(os.sep, ".").replace("/", ".")

def sort_dict(obj):
    """Convert nested dicts/lists into an order-independent, hashable-ish form.

    * dict: a tuple of ``(key, sort_dict(value))`` pairs in sorted-key order.
    * list: a tuple of its items in sorted order.
    * anything else: returned unchanged.

    The result's repr feeds ``string_hash``, so output for inputs that already
    sorted cleanly is kept exactly as before. Only inputs that used to raise
    TypeError go through the fallback: lists containing unorderable items
    (e.g. dicts, or mixed int/str) and dicts with mixed-type keys. There,
    items are normalized recursively (lists only) and ordered by ``repr``.
    """
    if isinstance(obj, list):
        try:
            return tuple(sorted(obj))
        except TypeError:
            return tuple(sorted((sort_dict(item) for item in obj), key=repr))
    if isinstance(obj, dict):
        try:
            keys = sorted(obj)
        except TypeError:
            keys = sorted(obj, key=repr)
        return tuple((key, sort_dict(obj[key])) for key in keys)
    return obj

def string_hash(obj):
    """Return the hex MD5 digest of ``obj``'s repr (UTF-8 encoded).

    Used as a stable fingerprint, not for security.
    """
    digest = md5()
    digest.update(obj.__repr__().encode())
    return digest.hexdigest()

def establish_communication(**kwargs):
    """Bind this process to its CUDA device and join the NCCL process group.

    Requires ``kwargs["device"]``, which is passed to ``torch.cuda.set_device``.
    The group is initialised with ``init_method='env://'``. Per the
    ``torch.distributed`` docs, the rendezvous settings (MASTER_ADDR,
    MASTER_PORT, RANK, WORLD_SIZE) therefore come from environment variables,
    presumably set by the launcher.
    """
    target_device = kwargs["device"]
    cuda.set_device(target_device)
    distributed.init_process_group("nccl", init_method='env://')

def recurrent_iter(obj):
    """Yield the items of ``obj`` endlessly, re-iterating it after each pass.

    Unlike ``itertools.cycle`` this calls ``iter(obj)`` afresh every pass, so
    a re-iterable object can produce new items each time. The generator
    stops as soon as a full pass yields nothing: an empty iterable, or a
    one-shot iterator that is exhausted. Previously that case spun forever
    without yielding.
    """
    while True:
        produced = False
        for item in obj:
            produced = True
            yield item
        if not produced:
            return

def myround(tens, dec = 3):
    """Round ``tens`` to ``dec`` decimal places by scale, round, unscale.

    Works with any object supporting ``*``, ``/`` and a ``.round()`` method
    (e.g. torch tensors, numpy arrays). Tie-breaking is whatever
    ``.round()`` of that type does.
    """
    factor = 10 ** dec
    scaled = (tens * factor).round()
    return scaled / factor

def printtensor(tensor):
    """Format every element of ``tensor`` with 3 decimals for log output.

    Output is a leading space followed by ``",<value> "`` per element, e.g.
    ``" ,1.000 ,2.500 "``. The input is flattened first, so 0-d and
    multi-dimensional tensors (or numpy arrays) work too. Previously they
    failed when iterating a float or formatting a nested list.
    """
    values = tensor.reshape(-1).tolist()
    return " " + "".join(f",{t:.3f} " for t in values)

def image_stat(img):
    """Summarize a 4-D tensor with per-dim-1 mean/std/max/min.

    Reductions run over dims (0, 2, 3), leaving one value per entry of dim 1.
    This presumably means (batch, channel, H, W) images; confirm with callers.
    Returns ``"<shape>,mean <...>std <...>max <...>min <...>"``, where each
    ``<...>`` is produced by ``printtensor``. The tensor is detached and
    copied to CPU first.
    """
    snapshot = img.detach().cpu()
    reduce_dims = (0, 2, 3)
    sections = []
    for label, method in (("mean", "mean"), ("std", "std"), ("max", "amax"), ("min", "amin")):
        per_channel = getattr(snapshot, method)(dim=reduce_dims)
        sections.append(f"{label} {printtensor(per_channel)}")
    return f"{snapshot.shape}," + "".join(sections)

def randomness_control(seed):
    """Seed every RNG this module uses and make cuDNN deterministic.

    Prints the seed, then seeds Python's ``random``, NumPy's global RNG,
    torch's CPU generator and all CUDA generators. Finally it forces
    deterministic cuDNN kernels and turns off cuDNN benchmark autotuning.
    """
    print("seed", seed)
    seeders = (random.seed, numpy.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

def merge_dict(*dicts):
    """Merge mappings left to right into a new dict; later keys win.

    The inputs are not modified. ``merge_dict()`` returns ``{}``.
    Previously the function returned the last input instead of the merged
    result, and raised NameError when called with no arguments.
    """
    merged_dict = {}
    for d in dicts:
        merged_dict.update(d)
    return merged_dict